# ------------------------------------------------------------------------------
# Plots yield rate for stress and cath tests
# Updated by: Cassidy Shubatt <cshubatt@gmail.com>
# To run: Rscript 05_plot_stress_cath_calib.R
# ------------------------------------------------------------------------------

# Libraries --------------------------------------------------------------------
library(here)
library(yaml)
library(data.table)
library(tidyverse)
library(glue)
library(ggplot2)

# Project helper modules: `u` provides get_mean_outcomes() and `a` provides
# tile_plot(), both used by the plotting pipeline below.
u <- modules::use(here("lib", "util.R"))
a <- modules::use(here("lib", "aesthetics.R"))

# Output directory for this script's figures. `recursive = TRUE` so that a
# checkout missing the intermediate "temp" folder still gets the directory;
# without it dir.create() only warns and returns FALSE, and the ggsave() calls
# at the end of the script would then fail.
temp <- here(
  "code", "03_score_diagnostics", "temp",
  "05_plot_stress_cath_calib"
)
if (!dir.exists(temp)) {
  dir.create(temp, recursive = TRUE)
}

# Helpers ----------------------------------------------------------------------
# Runs u$get_mean_outcomes() on the rows of the cohort flagged by `subsample`.
#
# subsample: name of a logical column of `data`; only rows where it is TRUE are
#   kept (FALSE and NA rows are dropped).
# outcome, x_var: column names forwarded to u$get_mean_outcomes(), which also
#   receives the literal "tested" as its first argument.
# ...: absorbs the other config-table columns passed in by pmap().
# data: cohort to subset. Defaults to the script-level `cohort`, so existing
#   pmap() calls over the config table behave as before.
# Returns whatever u$get_mean_outcomes() returns.
get_mean_outcomes <- function(subsample, outcome, x_var, ..., data = cohort) {
  # which() discards NA flags (e.g. from `cath_010_day & !stress_010_day`);
  # indexing a data.frame with a logical NA would instead add all-NA rows.
  keep_rows <- which(data[[subsample]])
  sub_df <- data[keep_rows, ]
  u$get_mean_outcomes("tested", outcome, x_var, sub_df)
}

# Adds x and y axis labels to a ggplot.
#
# gg: ggplot object to label.
# outcome: outcome column name. "stent_or_cabg_010_day" is shown as
#   "Yield of <Test> Testing"; any other value is used verbatim.
# x_var: unused; kept in the signature for the pmap() call over the config.
# subsample: subsample column name; containing "stress" selects "Stress",
#   containing "cath" selects "Cath", otherwise no test type is shown.
# ...: absorbs the remaining config-table columns passed in by pmap().
# Returns `gg` with the labels added.
add_labels <- function(gg, outcome, x_var, subsample, ...) {
  xlabel <- "Percentile of Predicted Risk"

  # Inputs are scalars (one config row), so plain if/else is enough.
  test_type <- if (grepl("stress", subsample)) {
    "Stress"
  } else if (grepl("cath", subsample)) {
    "Cath"
  } else {
    ""
  }

  ylabel <- if (identical(outcome, "stent_or_cabg_010_day")) {
    # Omit an empty test type so the label never contains a double space.
    label_words <- c("Yield of", if (nzchar(test_type)) test_type, "Testing")
    paste(label_words, collapse = " ")
  } else {
    outcome
  }

  gg + labs(x = xlabel, y = ylabel)
}

# Builds the output PNG path for one config row:
#   <out_dir>/<outcome>__by__<x label>__for__<subsample>.png
#
# If `x_var` contains "tile_stent_or_cabg", the x label is "tile_" followed by
# the first number in `x_var`, zero-padded to three characters (e.g.
# "tile_stent_or_cabg_010_tested" -> "tile_010"). Otherwise `x_var` is used as is.
# ...: absorbs the remaining config-table columns passed in by pmap().
# out_dir: directory for the file; defaults to the script's `temp` directory.
# Returns the path as a character string.
get_filename <- function(outcome, subsample, x_var, ..., out_dir = temp) {
  # Parse only tile names: parse_number() warns on names with no number in them.
  x_var_lab <- if (grepl("tile_stent_or_cabg", x_var)) {
    tile_num <- str_pad(parse_number(x_var), 3, pad = "0")
    glue("tile_{tile_num}")
  } else {
    x_var
  }
  png_name <- glue("{outcome}__by__{x_var_lab}__for__{subsample}.png")
  file.path(out_dir, png_name)
}

# Load Data --------------------------------------------------------------------
message("Loading data...")
paths <- read_yaml(here("lib", "filepaths.yml"))
# NOTE(review): presumably referenced as {overnight_lab} inside the
# `test_cohort` path template that glue() interpolates on the next statement
# (glue looks names up in the calling environment) -- confirm against
# lib/filepaths.yml before removing or renaming.
overnight_lab <- ""
# Load the analysis cohort and drop rows where `exclude` is TRUE (filter()
# also drops rows where it is NA). Then add `cath_first_010_day`: TRUE when
# `cath_010_day` is TRUE and `stress_010_day` is FALSE. It can be NA if either
# input is NA.
cohort <- readRDS(glue(paths$analysis$test_cohort)) %>%
  filter(!exclude) %>%
  mutate(cath_first_010_day = cath_010_day & !stress_010_day)

# Outcome Config ---------------------------------------------------------------
message("Preparing config table...")
# One row per outcome x grouping-variable x subsample combination. Each row
# drives one figure. crossing() de-duplicates and sorts its inputs, so the
# subsample rows come out in alphabetical order. Each subsample names a logical
# column of `cohort`.
config <- crossing(
  outcome = "stent_or_cabg_010_day",
  x_var = "tile_stent_or_cabg_010_tested",
  subsample = c("stress_010_day", "cath_first_010_day")
)

# Plotting Means ---------------------------------------------------------------
message("Plotting calibration curves...")
# Build one figure per config row. Inside each mutate(), `.` is the data
# entering that step, so every pmap() call passes ALL columns built so far to
# its function by name. The local helpers absorb the columns they don't use
# via `...`. The steps must stay separate and in this order, because each one
# needs the column added by the previous step:
#   mean_outcomes: result of get_mean_outcomes() for the row's subsample
#   gg:            plot from a$tile_plot() (presumably drawn from
#                  mean_outcomes -- confirm in lib/aesthetics.R)
#   plot:          gg with axis labels from add_labels()
#   filename:      output path from get_filename() (a list-column)
plots <- config %>%
  mutate(mean_outcomes = pmap(., get_mean_outcomes)) %>%
  mutate(gg = pmap(., a$tile_plot)) %>%
  mutate(plot = pmap(., add_labels)) %>%
  mutate(filename = pmap(., get_filename))

# Save -------------------------------------------------------------------------
message("Saving...")
# ggsave() is called only to write each PNG, so use pwalk() rather than pmap():
# no result list is collected or kept. The `plot` and `filename` columns are
# matched to ggsave()'s arguments of the same names.
plots %>%
  select(plot, filename) %>%
  pwalk(ggsave, width = 10, height = 7, units = "in")

message("Done.")
